from deepsearcher.agent import generate_sub_queries, generate_gap_queries, generate_final_answer
from deepsearcher.agent.search_vdb import search_chunks_from_vectordb
from deepsearcher.vector_db.base import deduplicate_results
# from deepsearcher.configuration import vector_db, embedding_model, llm
from deepsearcher import configuration
from deepsearcher.tools import log
import asyncio


async def async_query(original_query: str, max_iter: int = 3) -> str:
    """Answer a query with an iterative retrieve-and-reflect loop over the vector DB.

    The LLM first splits ``original_query`` into sub-queries. Each iteration
    searches the vector DB for the current batch of queries concurrently and
    accumulates the de-duplicated chunks. Except on the last iteration, it then
    asks the LLM for "gap" queries to search in the next round. Finally, the LLM
    composes an answer from every chunk gathered.

    Args:
        original_query: The user's question.
        max_iter: Maximum number of search iterations.

    Returns:
        The value produced by ``generate_final_answer``, or ``None`` when the
        LLM generates no sub-queries.
    """
    log.color_print(f"<query> {original_query} </query>\n")
    all_search_res = []
    all_sub_queries = []

    ### SUB QUERIES ###
    sub_queries = generate_sub_queries(original_query)
    if not sub_queries:
        log.color_print("No sub queries were generated by the LLM. Exiting.")
        # NOTE: returns None despite the `-> str` annotation; kept for existing callers.
        return None
    log.color_print(f"<think> Break down the original query into new sub queries: {sub_queries}</think>\n")
    all_sub_queries.extend(sub_queries)
    sub_gap_queries = sub_queries

    for iteration in range(max_iter):
        log.color_print(f">> Iteration: {iteration + 1}\n")

        # Search every pending query concurrently; each call also receives the
        # whole batch of pending queries as its second argument.
        search_tasks = [
            search_chunks_from_vectordb(sub_query, sub_gap_queries)
            for sub_query in sub_gap_queries
        ]
        search_results = await asyncio.gather(*search_tasks)
        # Flatten the per-query result lists, then drop duplicate chunks.
        search_res_from_vectordb = deduplicate_results(
            [chunk for result in search_results for chunk in result]
        )
        # TODO: merge (de-duplicated) internet search results here once supported.
        all_search_res.extend(search_res_from_vectordb)

        # No further round follows the last iteration. Reflecting here would
        # waste an LLM call and hand never-searched queries to the final answer.
        if iteration == max_iter - 1:
            log.color_print("<think> Exceeded maximum iterations. Exiting. </think>\n")
            break

        ### REFLECTION & GET GAP QUERIES ###
        log.color_print("<think> Reflecting on the search results... </think>\n")
        sub_gap_queries = generate_gap_queries(original_query, all_sub_queries, all_search_res)
        if not sub_gap_queries:
            log.color_print("<think> No new search queries were generated. Exiting. </think>\n")
            break
        log.color_print(f"<think> New search queries for next iteration: {sub_gap_queries} </think>\n")
        all_sub_queries.extend(sub_gap_queries)

    ### GENERATE FINAL ANSWER ###
    log.color_print("<think> Generating final answer... </think>\n")
    # Chunks from different iterations may repeat; de-duplicate across all of them.
    all_search_res = deduplicate_results(all_search_res)
    final_answer = generate_final_answer(original_query, all_sub_queries, all_search_res)
    log.color_print("\n==== FINAL ANSWER====\n")
    log.color_print(final_answer)
    return final_answer



def naive_rag_query(query: str, collection: str=None, top_k=10):
    """Answer ``query`` with a single retrieve-then-summarize pass (no agent loop).

    Args:
        query: The user's question.
        collection: Name of the vector DB collection to search. When falsy,
            every collection reported by ``vector_db.list_collections()`` is
            searched, with ``top_k`` split evenly across them. At least one
            hit is requested per collection.
        top_k: Number of chunks to retrieve from the given collection, or the
            total budget to split across all collections.

    Returns:
        The ``content`` attribute of the LLM's chat response.
    """
    vector_db = configuration.vector_db
    embedding_model = configuration.embedding_model
    llm = configuration.llm

    if collection:
        # Honour the caller's top_k (it was previously hard-coded to 10 here).
        retrieval_res = vector_db.search_data(
            collection=collection, vector=embedding_model.embed_query(query), top_k=top_k
        )
    else:
        collection_names = [col_info.collection_name for col_info in vector_db.list_collections()]
        retrieval_res = []
        if collection_names:
            # Embed once: the query vector is the same for every collection.
            query_vector = embedding_model.embed_query(query)
            # Integer division yields 0 when there are more collections than
            # top_k; always request at least one hit per collection.
            per_collection_top_k = max(1, top_k // len(collection_names))
            for name in collection_names:
                retrieval_res.extend(
                    vector_db.search_data(collection=name, vector=query_vector, top_k=per_collection_top_k)
                )
        # Drop duplicate chunks returned by different collections.
        retrieval_res = deduplicate_results(retrieval_res)

    # Use metadata["wider_text"] when present, otherwise the chunk's own text.
    chunk_texts = [
        chunk.metadata["wider_text"] if "wider_text" in chunk.metadata else chunk.text
        for chunk in retrieval_res
    ]
    mini_chunk_str = "".join(
        f"<chunk_{i}>\n{text}\n</chunk_{i}>\n" for i, text in enumerate(chunk_texts)
    )

    summary_prompt = f"""You are a AI content analysis expert, good at summarizing content. Please summarize a specific and detailed answer or report based on the previous queries and the retrieved document chunks.

    Original Query: {query}
    Related Chunks: 
    {mini_chunk_str}
    """
    response = llm.chat([{"role": "user", "content": summary_prompt}])
    return response.content

# Add a wrapper function to support synchronous calls
def query(original_query: str, max_iter: int = 3) -> str:
    """Synchronous entry point: run ``async_query`` to completion and return its result.

    Drives the coroutine with ``asyncio.run``, so this must not be called from
    inside an already-running event loop.
    """
    pending = async_query(original_query, max_iter=max_iter)
    return asyncio.run(pending)